{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Othello_GPT.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a demo notebook porting the weights of the Othello-GPT model from the excellent [Emergent World Representations](https://arxiv.org/pdf/2210.13382.pdf) paper to my TransformerLens library. Check out the accompanying [blog post](https://thegradient.pub/othello/), [paper](https://arxiv.org/pdf/2210.13382.pdf), and [GitHub repo](https://github.com/likenneth/othello_world/).\n",
    "\n",
    "I think this is a super interesting paper, and I want to better enable work trying to reverse-engineer this model! I'm particularly curious about: \n",
    "* Why do non-linear probes work much better than linear probes? \n",
    "    * Is the model internally representing the board in a usable yet non-linear way? \n",
    "    * Is there a representation of simpler concepts (e.g. diagonal lines on the board, number of black pieces, whether a cell is blank) that the non-linear probe uses to compute board positions, but where the model internally reasons in this simpler representation?\n",
    "* What's going on with the model editing? \n",
    "    * The paper edits across many layers at once. What's the minimal edit that works? \n",
    "        * Can we edit just before the final layer? \n",
    "        * Can we do a single edit rather than across many layers?\n",
    "    * If we contrast model activations pre and post edit, what changes? \n",
    "        * Which components shift their output and how does this affect the logits? \n",
    "        * Is there significant depth of composition, or does it just affect the output logits?\n",
    "* Can we find any non-trivial circuits in the model?\n",
    "    * Start with [exploratory techniques](https://neelnanda.io/exploratory-analysis-demo), like direct logit attribution, or just looking at head attention patterns, and try to get traction\n",
    "    * Pick a simple sub-task, e.g. figuring out whether a cell is blank, and try to interpret that.\n",
    "\n",
    "\n",
    "I uploaded pre-converted checkpoints to HuggingFace; the code snippet after the setup section downloads them automatically.\n",
    "\n",
    "If you would rather start from the author's original checkpoints, the script further below loads one and converts it to the TransformerLens format.\n",
    "\n",
    "To get started, check out the TransformerLens [main tutorial](https://neelnanda.io/transformer-lens-demo) and [tutorial on exploratory techniques](https://neelnanda.io/exploratory-analysis-demo), and the author's [excellent GitHub](https://github.com/likenneth/othello_world/) (Ot**hello world**) for various notebooks demonstrating their code, showing how to load inputs, etc. Also check out my [concrete open problems in mechanistic interpretability](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj) sequence, especially the algorithmic problems post, for tips on this style of research."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Setup (Skip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running as a Jupyter notebook - intended for development only!\n",
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "import os\n",
    "\n",
    "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
    "DEVELOPMENT_MODE = False\n",
    "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
    "try:\n",
    "    import google.colab\n",
    "\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "\n",
    "    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
    "    # # Install another version of node that makes PySvelte work way faster\n",
    "    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
    "    # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
    "except ImportError:\n",
    "    IN_COLAB = False\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    ipython = get_ipython()\n",
    "    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel\n",
    "    ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
    "    ipython.run_line_magic(\"autoreload\", \"2\")\n",
    "\n",
    "if IN_COLAB or IN_GITHUB:\n",
    "    %pip install transformer_lens\n",
    "    %pip install circuitsvis\n",
    "    %pip install torchtyping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using renderer: colab\n"
     ]
    }
   ],
   "source": [
    "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
    "import plotly.io as pio\n",
    "\n",
    "if IN_COLAB or not DEVELOPMENT_MODE:\n",
    "    pio.renderers.default = \"colab\"\n",
    "else:\n",
    "    pio.renderers.default = \"notebook_connected\"\n",
    "print(f\"Using renderer: {pio.renderers.default}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div id=\"circuits-vis-7eb230d9-d14a\" style=\"margin: 15px 0;\"/>\n",
       "    <script crossorigin type=\"module\">\n",
       "    import { render, Hello } from \"https://unpkg.com/circuitsvis@1.43.2/dist/cdn/esm.js\";\n",
       "    render(\n",
       "      \"circuits-vis-7eb230d9-d14a\",\n",
       "      Hello,\n",
       "      {\"name\": \"Neel\"}\n",
       "    )\n",
       "    </script>"
      ],
      "text/plain": [
       "<circuitsvis.utils.render.RenderedHTML at 0x13f282a90>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import circuitsvis as cv\n",
    "\n",
    "# Testing that the library works\n",
    "cv.examples.hello(\"Neel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import stuff\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "import einops\n",
    "from fancy_einsum import einsum\n",
    "import tqdm.auto as tqdm\n",
    "import random\n",
    "from pathlib import Path\n",
    "import plotly.express as px\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchtyping import TensorType as TT\n",
    "from typing import List, Union, Optional\n",
    "from functools import partial\n",
    "import copy\n",
    "\n",
    "import itertools\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "import dataclasses\n",
    "import datasets\n",
    "from IPython.display import HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformer_lens\n",
    "import transformer_lens.utils as utils\n",
    "from transformer_lens.hook_points import (\n",
    "    HookedRootModule,\n",
    "    HookPoint,\n",
    ")  # Hooking utilities\n",
    "from transformer_lens import (\n",
    "    HookedTransformer,\n",
    "    HookedTransformerConfig,\n",
    "    FactoredMatrix,\n",
    "    ActivationCache,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We turn off automatic differentiation to save GPU memory, since this notebook focuses on model inference rather than training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x2b63dad90>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Plotting helper functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
    "    px.imshow(\n",
    "        utils.to_numpy(tensor),\n",
    "        color_continuous_midpoint=0.0,\n",
    "        color_continuous_scale=\"RdBu\",\n",
    "        labels={\"x\": xaxis, \"y\": yaxis},\n",
    "        **kwargs,\n",
    "    ).show(renderer)\n",
    "\n",
    "\n",
    "def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
    "    px.line(utils.to_numpy(tensor), labels={\"x\": xaxis, \"y\": yaxis}, **kwargs).show(\n",
    "        renderer\n",
    "    )\n",
    "\n",
    "\n",
    "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n",
    "    x = utils.to_numpy(x)\n",
    "    y = utils.to_numpy(y)\n",
    "    px.scatter(\n",
    "        y=y, x=x, labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis}, **kwargs\n",
    "    ).show(renderer)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Othello GPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOAD_AND_CONVERT_CHECKPOINT = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Othello-GPT architecture; these values must match the checkpoint loaded below\n",
    "cfg = HookedTransformerConfig(\n",
    "    n_layers=8,\n",
    "    d_model=512,\n",
    "    d_head=64,\n",
    "    n_heads=8,\n",
    "    d_mlp=2048,\n",
    "    d_vocab=61,\n",
    "    n_ctx=59,\n",
    "    act_fn=\"gelu\",\n",
    "    normalization_type=\"LNPre\",\n",
    ")\n",
    "model = HookedTransformer(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "566376c7d21e4645958b59f758d2a971",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "./synthetic_model.pth:   0%|          | 0.00/101M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NBVAL_IGNORE_OUTPUT\n",
    "sd = utils.download_file_from_hf(\n",
    "    \"NeelNanda/Othello-GPT-Transformer-Lens\", \"synthetic_model.pth\"\n",
    ")\n",
    "# champion_ship_sd = utils.download_file_from_hf(\"NeelNanda/Othello-GPT-Transformer-Lens\", \"championship_model.pth\")\n",
    "model.load_state_dict(sd)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Code to load one of the author's checkpoints and convert it to the TransformerLens format:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
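    "def convert_to_transformer_lens_format(in_sd, n_layers=8, n_heads=8):\n",
    "    \"\"\"Convert an original Othello-GPT state dict to TransformerLens parameter names and shapes.\"\"\"\n",
    "    out_sd = {}\n",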
    "    out_sd[\"pos_embed.W_pos\"] = in_sd[\"pos_emb\"].squeeze(0)\n",
    "    out_sd[\"embed.W_E\"] = in_sd[\"tok_emb.weight\"]\n",
    "\n",
    "    out_sd[\"ln_final.w\"] = in_sd[\"ln_f.weight\"]\n",
    "    out_sd[\"ln_final.b\"] = in_sd[\"ln_f.bias\"]\n",
    "    out_sd[\"unembed.W_U\"] = in_sd[\"head.weight\"].T\n",
    "\n",
    "    for layer in range(n_layers):\n",
    "        out_sd[f\"blocks.{layer}.ln1.w\"] = in_sd[f\"blocks.{layer}.ln1.weight\"]\n",
    "        out_sd[f\"blocks.{layer}.ln1.b\"] = in_sd[f\"blocks.{layer}.ln1.bias\"]\n",
    "        out_sd[f\"blocks.{layer}.ln2.w\"] = in_sd[f\"blocks.{layer}.ln2.weight\"]\n",
    "        out_sd[f\"blocks.{layer}.ln2.b\"] = in_sd[f\"blocks.{layer}.ln2.bias\"]\n",
    "\n",
    "        out_sd[f\"blocks.{layer}.attn.W_Q\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.query.weight\"],\n",
    "            \"(head d_head) d_model -> head d_model d_head\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.b_Q\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.query.bias\"],\n",
    "            \"(head d_head) -> head d_head\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.W_K\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.key.weight\"],\n",
    "            \"(head d_head) d_model -> head d_model d_head\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.b_K\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.key.bias\"],\n",
    "            \"(head d_head) -> head d_head\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.W_V\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.value.weight\"],\n",
    "            \"(head d_head) d_model -> head d_model d_head\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.b_V\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.value.bias\"],\n",
    "            \"(head d_head) -> head d_head\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.W_O\"] = einops.rearrange(\n",
    "            in_sd[f\"blocks.{layer}.attn.proj.weight\"],\n",
    "            \"d_model (head d_head) -> head d_head d_model\",\n",
    "            head=n_heads,\n",
    "        )\n",
    "        out_sd[f\"blocks.{layer}.attn.b_O\"] = in_sd[f\"blocks.{layer}.attn.proj.bias\"]\n",
    "\n",
    "        out_sd[f\"blocks.{layer}.mlp.b_in\"] = in_sd[f\"blocks.{layer}.mlp.0.bias\"]\n",
    "        out_sd[f\"blocks.{layer}.mlp.W_in\"] = in_sd[f\"blocks.{layer}.mlp.0.weight\"].T\n",
    "        out_sd[f\"blocks.{layer}.mlp.b_out\"] = in_sd[f\"blocks.{layer}.mlp.2.bias\"]\n",
    "        out_sd[f\"blocks.{layer}.mlp.W_out\"] = in_sd[f\"blocks.{layer}.mlp.2.weight\"].T\n",
    "\n",
    "    return out_sd\n",
    "\n",
    "\n",
    "if LOAD_AND_CONVERT_CHECKPOINT:\n",
    "    synthetic_checkpoint = torch.load(\"/workspace/othello_world/gpt_synthetic.ckpt\")\n",
    "    for name, param in synthetic_checkpoint.items():\n",
    "        if name.startswith(\"blocks.0\") or not name.startswith(\"blocks\"):\n",
    "            print(name, param.shape)\n",
    "\n",
    "    cfg = HookedTransformerConfig(\n",
    "        n_layers=8,\n",
    "        d_model=512,\n",
    "        d_head=64,\n",
    "        n_heads=8,\n",
    "        d_mlp=2048,\n",
    "        d_vocab=61,\n",
    "        n_ctx=59,\n",
    "        act_fn=\"gelu\",\n",
    "        normalization_type=\"LNPre\",\n",
    "    )\n",
    "    model = HookedTransformer(cfg)\n",
    "\n",
    "    model.load_and_process_state_dict(\n",
    "        convert_to_transformer_lens_format(synthetic_checkpoint)\n",
    "    )"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing code for the synthetic checkpoint giving the correct outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[21, 41, 40, 34, 40, 41,  3, 11, 21, 43, 40, 21, 28, 50, 33, 50, 33,  5,\n",
       "         33,  5, 52, 46, 14, 46, 14, 47, 38, 57, 36, 50, 38, 15, 28, 26, 28, 59,\n",
       "         50, 28, 14, 28, 28, 28, 28, 45, 28, 35, 15, 14, 30, 59, 49, 59, 15, 15,\n",
       "         14, 15,  8,  7,  8]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# An example input\n",
    "sample_input = torch.tensor(\n",
    "    [\n",
    "        [\n",
    "            20,\n",
    "            19,\n",
    "            18,\n",
    "            10,\n",
    "            2,\n",
    "            1,\n",
    "            27,\n",
    "            3,\n",
    "            41,\n",
    "            42,\n",
    "            34,\n",
    "            12,\n",
    "            4,\n",
    "            40,\n",
    "            11,\n",
    "            29,\n",
    "            43,\n",
    "            13,\n",
    "            48,\n",
    "            56,\n",
    "            33,\n",
    "            39,\n",
    "            22,\n",
    "            44,\n",
    "            24,\n",
    "            5,\n",
    "            46,\n",
    "            6,\n",
    "            32,\n",
    "            36,\n",
    "            51,\n",
    "            58,\n",
    "            52,\n",
    "            60,\n",
    "            21,\n",
    "            53,\n",
    "            26,\n",
    "            31,\n",
    "            37,\n",
    "            9,\n",
    "            25,\n",
    "            38,\n",
    "            23,\n",
    "            50,\n",
    "            45,\n",
    "            17,\n",
    "            47,\n",
    "            28,\n",
    "            35,\n",
    "            30,\n",
    "            54,\n",
    "            16,\n",
    "            59,\n",
    "            49,\n",
    "            57,\n",
    "            14,\n",
    "            15,\n",
    "            55,\n",
    "            7,\n",
    "        ]\n",
    "    ]\n",
    ")\n",
    "# The argmax of the output (ie the most likely next move from each position)\n",
    "sample_output = torch.tensor(\n",
    "    [\n",
    "        [\n",
    "            21,\n",
    "            41,\n",
    "            40,\n",
    "            34,\n",
    "            40,\n",
    "            41,\n",
    "            3,\n",
    "            11,\n",
    "            21,\n",
    "            43,\n",
    "            40,\n",
    "            21,\n",
    "            28,\n",
    "            50,\n",
    "            33,\n",
    "            50,\n",
    "            33,\n",
    "            5,\n",
    "            33,\n",
    "            5,\n",
    "            52,\n",
    "            46,\n",
    "            14,\n",
    "            46,\n",
    "            14,\n",
    "            47,\n",
    "            38,\n",
    "            57,\n",
    "            36,\n",
    "            50,\n",
    "            38,\n",
    "            15,\n",
    "            28,\n",
    "            26,\n",
    "            28,\n",
    "            59,\n",
    "            50,\n",
    "            28,\n",
    "            14,\n",
    "            28,\n",
    "            28,\n",
    "            28,\n",
    "            28,\n",
    "            45,\n",
    "            28,\n",
    "            35,\n",
    "            15,\n",
    "            14,\n",
    "            30,\n",
    "            59,\n",
    "            49,\n",
    "            59,\n",
    "            15,\n",
    "            15,\n",
    "            14,\n",
    "            15,\n",
    "            8,\n",
    "            7,\n",
    "            8,\n",
    "        ]\n",
    "    ]\n",
    ")\n",
    "model(sample_input).argmax(dim=-1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
